import config
from pymilvus import connections, FieldSchema, DataType, CollectionSchema, Collection, utility

class Milvus:
    """Wrapper around one Milvus collection used to store Q&A embeddings.

    The connection settings and the collection name are read from the
    enabled ``milvus`` entry of ``config.get_config().models.vdbs``.
    """

    # Metric shared by index creation and search so the two always agree.
    _METRIC_TYPE = "L2"

    def __init__(self):
        """Locate the enabled Milvus configuration and connect to the server.

        Raises:
            RuntimeError: if the configuration has no entry named ``milvus``
                with ``enable == 1``.
        """
        super().__init__()

        self._db = None
        for vdb in config.get_config().models.vdbs:
            if vdb.name == "milvus" and vdb.enable == 1:
                # No break: if several entries match, the last one wins.
                self._db = vdb

        if self._db is None:
            # Fail with an explicit message rather than an opaque
            # "object has no attribute '_db'" on the connect call below.
            raise RuntimeError(
                "no enabled 'milvus' entry found in config models.vdbs"
            )

        # The return value of connect() is stored as-is; the other methods
        # address the connection through the "default" alias instead.
        self._connections = connections.connect(
            host=self._db.host,
            port=self._db.port,
            user=self._db.user,
            password=self._db.password,
        )

    def create_collection(self, dim, nlist=1):
        """Create the configured collection, index its vectors and load it.

        Args:
            dim: dimensionality of the ``embedding`` vector field.
            nlist: ``nlist`` parameter of the IVF_FLAT index. Defaults to 1.
        """
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
            FieldSchema(name="question", dtype=DataType.VARCHAR, max_length=500),
            FieldSchema(name="answer", dtype=DataType.VARCHAR, max_length=3000),
            FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=200),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
        ]
        schema = CollectionSchema(
            fields=fields, description="", enable_dynamic_field=True
        )
        collection = Collection(
            name=self._db.collection_name, schema=schema, using="default"
        )
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": self._METRIC_TYPE,
            "params": {"nlist": nlist},
        }
        collection.create_index(field_name="embedding", index_params=index_params)
        collection.load()

    def delete_collection(self):
        """Drop the configured collection if it exists; otherwise do nothing."""
        name = self._db.collection_name
        if utility.has_collection(collection_name=name):
            utility.drop_collection(collection_name=name)

    def upsert_document(self, entities):
        """Insert ``entities`` into the configured collection and flush.

        NOTE(review): despite the method name, this calls
        ``Collection.insert``; whether rows with an existing primary key are
        replaced depends on Milvus' insert semantics -- confirm.

        Args:
            entities: data passed unchanged to ``Collection.insert``.

        Returns:
            ``False`` if the collection does not exist, ``True`` once the
            insert and flush calls have returned.
        """
        name = self._db.collection_name
        if not utility.has_collection(collection_name=name):
            return False

        collection = Collection(name=name, using="default")
        collection.insert(entities)
        collection.flush()
        return True

    def search_document(self, data, top_k, nprobe=1):
        """Run a vector search and return the hits of the first query.

        Args:
            data: query vector(s) passed unchanged to ``Collection.search``.
                Only the hits for the first query are returned.
            top_k: maximum number of hits to return.
            nprobe: ``nprobe`` search parameter. Defaults to 1.

        Returns:
            A list of dicts, one per hit, containing ``id``, ``distance`` and
            every field reported by the hit's entity. The list is empty when
            the collection does not exist or the search yields no result set.
        """
        name = self._db.collection_name
        if not utility.has_collection(collection_name=name):
            return []

        collection = Collection(name=name, using="default")
        search_param = {
            "metric_type": self._METRIC_TYPE,
            "params": {"nprobe": nprobe},
            "offset": 0,
        }
        results = collection.search(
            data=data,
            anns_field="embedding",
            param=search_param,
            limit=top_k,
            output_fields=["*"],
        )
        # Guard against an empty result set before indexing the first query.
        if len(results) == 0:
            return []
        return self._hits_to_documents(results[0])

    @staticmethod
    def _hits_to_documents(hits):
        """Convert an iterable of search hits into plain dictionaries."""
        documents = []
        for hit in hits:
            entity = hit.entity
            document = {"id": hit.id, "distance": hit.distance}
            # Entity fields are applied last, so a field named "id" or
            # "distance" overrides the hit-level value.
            for field in entity.fields:
                document[field] = entity.get(field)
            documents.append(document)
        return documents